{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "PX4QQ3CFW1TF"
   },
   "source": [
    "# **MultiRC** - Multihop multiple-choice question answering dataset\n",
    "\n",
    "# **Model** - NER-based QA\n",
    "\n",
    "**APPROACH** -\n",
    "\n",
    "**Dataset Preparation**\n",
    "1. Concatenate paragraph + question + answers into a single context\n",
    "2. Mark the first token of each segment with a distinguishing tag: paragraph sentence (P), question (Q), correct answer (C), wrong answer (W); all remaining tokens get the inside tag (I)\n",
    "3. The resulting dataset is a CSV file with the following structure:\n",
    "\n",
    "\\<ID, TOKEN, TAG\\>\n",
    "\n",
    "where,\n",
    "\n",
    "ID: unique for every (paragraph, question, answers) combination\n",
    "\n",
    "TOKEN: one token of the concatenated and tokenized paragraph + question + answer options\n",
    "\n",
    "TAG: predetermined tag for the token's position in the context\n",
    "\n",
    "\n",
    "**Model Preparation**\n",
    "\n",
    "4. Train the model to learn this variation of BIO tagging\n",
    "\n",
    "**Evaluation Preparation**\n",
    "\n",
    "5. Evaluate the model's performance against the expected results: each correct answer should be tagged C followed by I tags, and each wrong answer W followed by I tags."
   ]
  },
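  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "The dataset preparation steps above can be sketched as follows. This is a minimal illustration, not the project's actual parsing script: the helper `build_rows`, the whitespace tokenization, and the toy example are assumptions made for this sketch.\n",
    "\n",
    "```python\n",
    "def build_rows(example_id, paragraph_sentences, question, answers):\n",
    "    # Hypothetical helper: returns (ID, TOKEN, TAG) rows for one\n",
    "    # paragraph/question/answers group.\n",
    "    # `answers` is a list of (text, is_correct) pairs.\n",
    "    rows = []\n",
    "\n",
    "    def add_segment(text, first_tag):\n",
    "        # The first token gets the segment tag; the rest get the inside tag.\n",
    "        for i, token in enumerate(text.split()):\n",
    "            rows.append((example_id, token, first_tag if i == 0 else 'I'))\n",
    "\n",
    "    for sentence in paragraph_sentences:\n",
    "        add_segment(sentence, 'P')\n",
    "    add_segment(question, 'Q')\n",
    "    for text, is_correct in answers:\n",
    "        add_segment(text, 'C' if is_correct else 'W')\n",
    "    return rows\n",
    "\n",
    "\n",
    "rows = build_rows(\n",
    "    1,\n",
    "    ['Animated history of the US .'],\n",
    "    'What is animated ?',\n",
    "    [('The history', True), ('The future', False)],\n",
    ")\n",
    "for row in rows:\n",
    "    print(row)\n",
    "```\n",
    "\n",
    "Each row is an `(ID, TOKEN, TAG)` triple, so writing `rows` out with a CSV writer would give the three-column layout described above."
   ]
  },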
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "ulTtzBt1xM9G"
   },
   "source": [
    "# NOTE: Search for \"TODO\" to switch between the original and sampled data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "9ddbuTPIUi7h"
   },
   "outputs": [],
   "source": [
    "%reload_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "jwPQNGAnWbHc"
   },
   "source": [
    "# Mounting data\n",
    "1. train.csv - training set\n",
    "2. dev.csv - test set\n",
    "\n",
    "Note: We use the validation set as our test set because the MultiRC test set is not publicly available, so its labels cannot be verified and model performance cannot be analysed on it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 122
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 23053,
     "status": "ok",
     "timestamp": 1588561961240,
     "user": {
      "displayName": "Soujanya Ranganatha Bhat",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjYZYrqnugKHbgoo144GZ9rzmvTfTGIL9eFkBCz=s64",
      "userId": "15617339232293464832"
     },
     "user_tz": 420
    },
    "id": "oVMX9PY6U-cc",
    "outputId": "54dc48da-b646-4a96-b36d-a9fac6080c84"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mounted at /content/gdrive\n"
     ]
    }
   ],
   "source": [
    "from google.colab import drive\n",
    "drive.mount('/content/gdrive')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "hB8x1XAPhB2D"
   },
   "outputs": [],
   "source": [
    "PARENT_DIR = \"/content/gdrive/My Drive/MultiRC_NER\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 24376,
     "status": "ok",
     "timestamp": 1588561962590,
     "user": {
      "displayName": "Soujanya Ranganatha Bhat",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjYZYrqnugKHbgoo144GZ9rzmvTfTGIL9eFkBCz=s64",
      "userId": "15617339232293464832"
     },
     "user_tz": 420
    },
    "id": "UutbeUmMVP_P",
    "outputId": "5d75b3de-bf63-4eda-a175-76881306847d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dev.csv\t\tdev_v2.csv     qa\t  train_sample.csv  vocab.txt\n",
      "dev_sample.csv\tparsing_v2.py  train.csv  train_v2.csv\n"
     ]
    }
   ],
   "source": [
    "!ls \"/content/gdrive/My Drive/MultiRC_NER/data\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "WaD3PslCWEVK"
   },
   "source": [
    "# Requirements"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 972
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 36740,
     "status": "ok",
     "timestamp": 1588561974972,
     "user": {
      "displayName": "Soujanya Ranganatha Bhat",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjYZYrqnugKHbgoo144GZ9rzmvTfTGIL9eFkBCz=s64",
      "userId": "15617339232293464832"
     },
     "user_tz": 420
    },
    "id": "WyQ9p51QWDHq",
    "outputId": "e9fa85f4-8661-40c8-9fce-a9e67edb270d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Successfully installed seqeval-0.0.12\n",
      "Successfully installed sacremoses-0.0.43 sentencepiece-0.1.86 tokenizers-0.5.2 transformers-2.8.0\n"
     ]
    }
   ],
   "source": [
    "!pip install seqeval\n",
    "!pip install transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "RcUnNx1WUi7l"
   },
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch.nn.functional as F\n",
    "from seqeval.metrics import accuracy_score, classification_report, f1_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 42016,
     "status": "ok",
     "timestamp": 1588561980269,
     "user": {
      "displayName": "Soujanya Ranganatha Bhat",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjYZYrqnugKHbgoo144GZ9rzmvTfTGIL9eFkBCz=s64",
      "userId": "15617339232293464832"
     },
     "user_tz": 420
    },
    "id": "nBU6k3dQUi7p",
    "outputId": "73d3d494-a32c-43e0-9803-3cf86ca2d389"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import os\n",
    "from tqdm import tqdm,trange\n",
    "from torch.optim import Adam\n",
    "from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler\n",
    "from keras.preprocessing.sequence import pad_sequences\n",
    "from sklearn.model_selection import train_test_split\n",
    "from transformers import BertTokenizer, BertConfig\n",
    "from transformers import BertForTokenClassification, AdamW"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 153
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 43182,
     "status": "ok",
     "timestamp": 1588561981449,
     "user": {
      "displayName": "Soujanya Ranganatha Bhat",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjYZYrqnugKHbgoo144GZ9rzmvTfTGIL9eFkBCz=s64",
      "userId": "15617339232293464832"
     },
     "user_tz": 420
    },
    "id": "-sUo6M6fUi7t",
    "outputId": "f735694b-bd6a-41b5-b9c4-75d71619ffba"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Keras                    2.3.1          \n",
      "Keras-Applications       1.0.8          \n",
      "Keras-Preprocessing      1.1.0          \n",
      "torch                    1.5.0+cu101    \n",
      "torchsummary             1.5.1          \n",
      "torchtext                0.3.1          \n",
      "torchvision              0.6.0+cu101    \n",
      "transformers             2.8.0          \n"
     ]
    }
   ],
   "source": [
    "# Check library version\n",
    "!pip list | grep -E 'transformers|torch|Keras'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "8HmaFwPpUi7w"
   },
   "source": [
    "This notebook works with the following environment:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "g1jIVHbJUi7x"
   },
   "source": [
    "- Keras 2.3.1\n",
    "- torch 1.5.0+cu101\n",
    "- transformers 2.8.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "BTn36HHiUi7x"
   },
   "source": [
    "# Introduction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "3g3rC1GUUi7y"
   },
   "source": [
    "Steps for NER-based QA with BERT:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "kTKLxKvLUi7z"
   },
   "source": [
    "- Load and preprocess data\n",
    "- Parse data\n",
    "- Make training data\n",
    "- Train model\n",
    "- Evaluate results\n",
    "- Predict results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "4X_eNNqGUi75"
   },
   "source": [
    "## Load data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "04kRhJ2mUi76"
   },
   "source": [
    "**Load CSV data**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "7sZnV5gGUi76"
   },
   "outputs": [],
   "source": [
    "data_path = PARENT_DIR + \"/data\" "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ctpiyXm9Ui7-"
   },
   "outputs": [],
   "source": [
    "# TODO: choose the input file: \"train.csv\" - original, \"train_sample.csv\" - sampled file (1/100th of the data).\n",
    "# \"train_v2.csv\" is the file currently in use.\n",
    "train_file_address = PARENT_DIR + \"/data/train_v2.csv\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Z1j-H3sKUi8A"
   },
   "outputs": [],
   "source": [
    "# Forward-fill missing values so that every row carries the ID of the sequence it belongs to\n",
    "# NOTE: encoding changed from latin1 to utf-8\n",
    "df_data = pd.read_csv(train_file_address, sep=\",\", encoding=\"utf-8\").ffill()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 44577,
     "status": "ok",
     "timestamp": 1588561982884,
     "user": {
      "displayName": "Soujanya Ranganatha Bhat",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjYZYrqnugKHbgoo144GZ9rzmvTfTGIL9eFkBCz=s64",
      "userId": "15617339232293464832"
     },
     "user_tz": 420
    },
    "id": "HuoBFgbsUi8D",
    "outputId": "30b91634-081c-477e-9e2b-869b6bf46644"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Index(['ID', 'TOKEN', 'TAG'], dtype='object')"
      ]
     },
     "execution_count": 12,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_data.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 669
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 44562,
     "status": "ok",
     "timestamp": 1588561982885,
     "user": {
      "displayName": "Soujanya Ranganatha Bhat",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjYZYrqnugKHbgoo144GZ9rzmvTfTGIL9eFkBCz=s64",
      "userId": "15617339232293464832"
     },
     "user_tz": 420
    },
    "id": "Gf7dOgcAUi8H",
    "outputId": "04845ea9-8e47-4ad4-cdca-ec045fd25c68"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>ID</th>\n",
       "      <th>TOKEN</th>\n",
       "      <th>TAG</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>Animated</td>\n",
       "      <td>P</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>history</td>\n",
       "      <td>I</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>of</td>\n",
       "      <td>I</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>the</td>\n",
       "      <td>I</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>US</td>\n",
       "      <td>I</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>1</td>\n",
       "      <td>.</td>\n",
       "      <td>I</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>1</td>\n",
       "      <td>Of</td>\n",
       "      <td>P</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>1</td>\n",
       "      <td>course</td>\n",
       "      <td>I</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>1</td>\n",
       "      <td>the</td>\n",
       "      <td>I</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9</th>\n",
       "      <td>1</td>\n",
       "      <td>cartoon</td>\n",
       "      <td>I</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>10</th>\n",
       "      <td>1</td>\n",
       "      <td>is</td>\n",
       "      <td>I</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>11</th>\n",
       "      <td>1</td>\n",
       "      <td>highly</td>\n",
       "      <td>I</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>1</td>\n",
       "      <td>oversimplified</td>\n",
       "      <td>I</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>1</td>\n",
       "      <td>and</td>\n",
       "      <td>I</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>14</th>\n",
       "      <td>1</td>\n",
       "      <td>most</td>\n",
       "      <td>I</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>15</th>\n",
       "      <td>1</td>\n",
       "      <td>critics</td>\n",
       "      <td>I</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>16</th>\n",
       "      <td>1</td>\n",
       "      <td>consider</td>\n",
       "      <td>I</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17</th>\n",
       "      <td>1</td>\n",
       "      <td>it</td>\n",
       "      <td>I</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18</th>\n",
       "      <td>1</td>\n",
       "      <td>one</td>\n",
       "      <td>I</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>1</td>\n",
       "      <td>of</td>\n",
       "      <td>I</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "    ID           TOKEN TAG\n",
       "0    1        Animated   P\n",
       "1    1         history   I\n",
       "2    1              of   I\n",
       "3    1             the   I\n",
       "4    1              US   I\n",
       "5    1               .   I\n",
       "6    1              Of   P\n",
       "7    1          course   I\n",
       "8    1             the   I\n",
       "9    1         cartoon   I\n",
       "10   1              is   I\n",
       "11   1          highly   I\n",
       "12   1  oversimplified   I\n",
       "13   1             and   I\n",
       "14   1            most   I\n",
       "15   1         critics   I\n",
       "16   1        consider   I\n",
       "17   1              it   I\n",
       "18   1             one   I\n",
       "19   1              of   I"
      ]
     },
     "execution_count": 13,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_data.head(n=20)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "qkB6hAWKUi8L"
   },
   "source": [
    "**TAG categories**\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 44678,
     "status": "ok",
     "timestamp": 1588561983018,
     "user": {
      "displayName": "Soujanya Ranganatha Bhat",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjYZYrqnugKHbgoo144GZ9rzmvTfTGIL9eFkBCz=s64",
      "userId": "15617339232293464832"
     },
     "user_tz": 420
    },
    "id": "-jE4_dKAUi8L",
    "outputId": "bc221769-c3dd-4d65-d4ab-53ead0c47a23"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['P', 'I', 'Q', 'C', 'W'], dtype=object)"
      ]
     },
     "execution_count": 14,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_data.TAG.unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 44661,
     "status": "ok",
     "timestamp": 1588561983019,
     "user": {
      "displayName": "Soujanya Ranganatha Bhat",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjYZYrqnugKHbgoo144GZ9rzmvTfTGIL9eFkBCz=s64",
      "userId": "15617339232293464832"
     },
     "user_tz": 420
    },
    "id": "PBwzJOIFUi8S",
    "outputId": "41483a0a-2548-452f-f6d1-2f1951f7e75c"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5131, 24067, 5)"
      ]
     },
     "execution_count": 15,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Data summary\n",
    "df_data['ID'].nunique(), df_data.TOKEN.nunique(), df_data.TAG.nunique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 119
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 44780,
     "status": "ok",
     "timestamp": 1588561983154,
     "user": {
      "displayName": "Soujanya Ranganatha Bhat",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjYZYrqnugKHbgoo144GZ9rzmvTfTGIL9eFkBCz=s64",
      "userId": "15617339232293464832"
     },
     "user_tz": 420
    },
    "id": "iLeBXoekUi8U",
    "outputId": "37744dc5-a032-4788-fa8d-33bd96f8eb74"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "I    1541258\n",
       "P      68462\n",
       "W      15218\n",
       "C      12025\n",
       "Q       5131\n",
       "Name: TAG, dtype: int64"
      ]
     },
     "execution_count": 16,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# TAG distribution\n",
    "df_data.TAG.value_counts()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "vKbYp0YaUi8X"
   },
   "source": [
    "### TAG explanation\n",
    "As shown above, there are 5 distinct tags: four segment-start tags (paragraph, question, correct answer, wrong answer) and one inside tag.\n",
    "- P: start of a paragraph sentence (first word)\n",
    "- Q: start of the question (first word)\n",
    "- C: start of a correct answer (first word)\n",
    "- W: start of a wrong answer (first word)\n",
    "- I: inside, any word that is not at the first position of its sentence"
   ]
  },
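  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "As a hedged sketch of how this tagging scheme can be read back, the hypothetical helper `split_segments` below (not part of the original notebook) starts a new segment at every non-`I` tag and appends `I` tokens to the current segment. The toy tokens and tags are made up for illustration.\n",
    "\n",
    "```python\n",
    "def split_segments(tokens, tags):\n",
    "    # Start a new segment at every non-'I' tag; 'I' tokens extend the current one.\n",
    "    segments = []\n",
    "    for token, tag in zip(tokens, tags):\n",
    "        if tag != 'I' or not segments:\n",
    "            segments.append((tag, [token]))\n",
    "        else:\n",
    "            segments[-1][1].append(token)\n",
    "    return [(tag, ' '.join(words)) for tag, words in segments]\n",
    "\n",
    "\n",
    "tokens = ['Of', 'course', 'Did', 'they', '?', 'Yes', 'No', 'way']\n",
    "tags = ['P', 'I', 'Q', 'I', 'I', 'C', 'W', 'I']\n",
    "segments = split_segments(tokens, tags)\n",
    "for tag, text in segments:\n",
    "    print(tag, text)\n",
    "```\n",
    "\n",
    "Applied to predicted tags, the same grouping would show which answer options the model marked as correct (C) or wrong (W)."
   ]
  },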
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "SmtZkBnFUi8e"
   },
   "source": [
    "## Parse data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "SRHouNE4Ui8e"
   },
   "source": [
    "**Parse data into document structure**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "HE87UalmUi8f"
   },
   "outputs": [],
   "source": [
    "class SentenceGetter(object):\n",
    "    \"\"\"Group the token-level rows into one (TOKEN, TAG) sequence per ID.\"\"\"\n",
    "\n",
    "    def __init__(self, data):\n",
    "        self.n_sent = 1\n",
    "        self.data = data\n",
    "        self.empty = False\n",
    "        # Pair every token with its tag within one ID group\n",
    "        agg_func = lambda s: list(zip(s[\"TOKEN\"].values.tolist(),\n",
    "                                      s[\"TAG\"].values.tolist()))\n",
    "        self.grouped = self.data.groupby(\"ID\").apply(agg_func)\n",
    "        self.sentences = [s for s in self.grouped]\n",
    "\n",
    "    def get_next(self):\n",
    "        # Return the next sequence by position, or None when all have been returned\n",
    "        try:\n",
    "            s = self.grouped.iloc[self.n_sent - 1]\n",
    "        except IndexError:\n",
    "            return None\n",
    "        self.n_sent += 1\n",
    "        return s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Eok4eiP9Ui8i"
   },
   "outputs": [],
   "source": [
    "# Get full document data structure\n",
    "getter = SentenceGetter(df_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 46094,
     "status": "ok",
     "timestamp": 1588561984499,
     "user": {
      "displayName": "Soujanya Ranganatha Bhat",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjYZYrqnugKHbgoo144GZ9rzmvTfTGIL9eFkBCz=s64",
      "userId": "15617339232293464832"
     },
     "user_tz": 420
    },
    "id": "gLrGqAXxUi8k",
    "outputId": "1137086f-e220-48ee-aaa5-e29e2d7f4c5c"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Animated',\n",
       " 'history',\n",
       " 'of',\n",
       " 'the',\n",
       " 'US',\n",
       " '.',\n",
       " 'Of',\n",
       " 'course',\n",
       " 'the',\n",
       " 'cartoon',\n",
       " 'is',\n",
       " 'highly',\n",
       " 'oversimplified',\n",
       " 'and',\n",
       " 'most',\n",
       " 'critics',\n",
       " 'consider',\n",
       " 'it',\n",
       " 'one',\n",
       " 'of',\n",
       " 'the',\n",
       " 'weakest',\n",
       " 'parts',\n",
       " 'of',\n",
       " 'the',\n",
       " 'film',\n",
       " '.',\n",
       " 'But',\n",
       " 'it',\n",
       " 'makes',\n",
       " 'a',\n",
       " 'valid',\n",
       " 'claim',\n",
       " 'which',\n",
       " 'you',\n",
       " 'ignore',\n",
       " 'entirely:',\n",
       " 'That',\n",
       " 'the',\n",
       " 'strategy',\n",
       " 'to',\n",
       " 'promote',\n",
       " 'gun',\n",
       " 'rights',\n",
       " 'for',\n",
       " 'white',\n",
       " 'people',\n",
       " 'and',\n",
       " 'to',\n",
       " 'outlaw',\n",
       " 'gun',\n",
       " 'possession',\n",
       " 'by',\n",
       " 'black',\n",
       " 'people',\n",
       " 'was',\n",
       " 'a',\n",
       " 'way',\n",
       " 'to',\n",
       " 'uphold',\n",
       " 'racism',\n",
       " 'without',\n",
       " 'letting',\n",
       " 'an',\n",
       " 'openly',\n",
       " 'terrorist',\n",
       " 'organization',\n",
       " 'like',\n",
       " 'the',\n",
       " 'KKK',\n",
       " 'flourish',\n",
       " '.',\n",
       " 'Did',\n",
       " 'the',\n",
       " '19th',\n",
       " 'century',\n",
       " 'NRA',\n",
       " 'in',\n",
       " 'the',\n",
       " 'southern',\n",
       " 'states',\n",
       " 'promote',\n",
       " 'gun',\n",
       " 'rights',\n",
       " 'for',\n",
       " 'black',\n",
       " 'people',\n",
       " '?',\n",
       " 'I',\n",
       " 'highly',\n",
       " 'doubt',\n",
       " 'it',\n",
       " '.',\n",
       " 'But',\n",
       " 'if',\n",
       " 'they',\n",
       " \"didn't\",\n",
       " 'one',\n",
       " 'of',\n",
       " 'their',\n",
       " 'functions',\n",
       " 'was',\n",
       " 'to',\n",
       " 'continue',\n",
       " 'the',\n",
       " 'racism',\n",
       " 'of',\n",
       " 'the',\n",
       " 'KKK',\n",
       " '.',\n",
       " 'This',\n",
       " 'is',\n",
       " 'the',\n",
       " 'key',\n",
       " 'message',\n",
       " 'of',\n",
       " 'this',\n",
       " 'part',\n",
       " 'of',\n",
       " 'the',\n",
       " 'animation',\n",
       " 'which',\n",
       " 'is',\n",
       " 'again',\n",
       " 'being',\n",
       " 'ignored',\n",
       " 'by',\n",
       " 'its',\n",
       " 'critics',\n",
       " '.',\n",
       " 'Buell',\n",
       " 'shooting',\n",
       " 'in',\n",
       " 'Flint',\n",
       " '.',\n",
       " 'You',\n",
       " 'write:',\n",
       " 'Fact:',\n",
       " 'The',\n",
       " 'little',\n",
       " 'boy',\n",
       " 'was',\n",
       " 'the',\n",
       " 'class',\n",
       " 'thug',\n",
       " 'already',\n",
       " 'suspended',\n",
       " 'from',\n",
       " 'school',\n",
       " 'for',\n",
       " 'stabbing',\n",
       " 'another',\n",
       " 'kid',\n",
       " 'with',\n",
       " 'a',\n",
       " 'pencil',\n",
       " 'and',\n",
       " 'had',\n",
       " 'fought',\n",
       " 'with',\n",
       " 'Kayla',\n",
       " 'the',\n",
       " 'day',\n",
       " 'before',\n",
       " '.',\n",
       " 'This',\n",
       " 'characterization',\n",
       " 'of',\n",
       " 'a',\n",
       " 'six-year-old',\n",
       " 'as',\n",
       " 'a',\n",
       " 'pencil-stabbing',\n",
       " 'thug',\n",
       " 'is',\n",
       " 'exactly',\n",
       " 'the',\n",
       " 'kind',\n",
       " 'of',\n",
       " 'hysteria',\n",
       " 'that',\n",
       " \"Moore's\",\n",
       " 'film',\n",
       " 'warns',\n",
       " 'against',\n",
       " '.',\n",
       " 'It',\n",
       " 'is',\n",
       " 'the',\n",
       " 'typical',\n",
       " 'right-wing',\n",
       " 'reaction',\n",
       " 'which',\n",
       " 'looks',\n",
       " 'for',\n",
       " 'simple',\n",
       " 'answers',\n",
       " 'that',\n",
       " 'do',\n",
       " 'not',\n",
       " 'contradict',\n",
       " 'the',\n",
       " 'Republican',\n",
       " 'mindset',\n",
       " '.',\n",
       " 'The',\n",
       " 'kid',\n",
       " 'was',\n",
       " 'a',\n",
       " 'little',\n",
       " 'bastard',\n",
       " 'and',\n",
       " 'the',\n",
       " 'parents',\n",
       " 'were',\n",
       " 'involved',\n",
       " 'in',\n",
       " 'drugs',\n",
       " '--',\n",
       " 'case',\n",
       " 'closed',\n",
       " '.',\n",
       " 'But',\n",
       " 'why',\n",
       " 'do',\n",
       " 'people',\n",
       " 'deal',\n",
       " 'with',\n",
       " 'drugs',\n",
       " '?',\n",
       " 'Because',\n",
       " \"it's\",\n",
       " 'so',\n",
       " 'much',\n",
       " 'fun',\n",
       " 'to',\n",
       " 'do',\n",
       " 'so',\n",
       " '?',\n",
       " 'It',\n",
       " 'is',\n",
       " 'by',\n",
       " 'now',\n",
       " 'well',\n",
       " 'documented',\n",
       " 'that',\n",
       " 'the',\n",
       " 'CIA',\n",
       " 'tolerated',\n",
       " 'crack',\n",
       " 'sales',\n",
       " 'in',\n",
       " 'US',\n",
       " 'cities',\n",
       " 'to',\n",
       " 'fund',\n",
       " 'the',\n",
       " 'operation',\n",
       " 'of',\n",
       " 'South',\n",
       " 'American',\n",
       " 'contras',\n",
       " 'It',\n",
       " 'is',\n",
       " 'equally',\n",
       " 'well',\n",
       " 'known',\n",
       " 'that',\n",
       " 'the',\n",
       " 'so-called',\n",
       " 'war',\n",
       " 'on',\n",
       " 'drugs',\n",
       " 'begun',\n",
       " 'under',\n",
       " 'the',\n",
       " 'Nixon',\n",
       " 'administration',\n",
       " 'is',\n",
       " 'a',\n",
       " 'failure',\n",
       " 'which',\n",
       " 'has',\n",
       " 'cost',\n",
       " 'hundreds',\n",
       " 'of',\n",
       " 'billions',\n",
       " 'and',\n",
       " 'made',\n",
       " 'America',\n",
       " 'the',\n",
       " 'world',\n",
       " 'leader',\n",
       " 'in',\n",
       " 'prison',\n",
       " 'population',\n",
       " '(both',\n",
       " 'in',\n",
       " 'relative',\n",
       " 'and',\n",
       " 'absolute',\n",
       " 'numbers)',\n",
       " '.',\n",
       " 'Does',\n",
       " 'the',\n",
       " 'author',\n",
       " 'claim',\n",
       " 'the',\n",
       " 'animated',\n",
       " 'films',\n",
       " 'message',\n",
       " 'is',\n",
       " 'that',\n",
       " 'the',\n",
       " 'NRA',\n",
       " 'upholds',\n",
       " 'racism',\n",
       " '?',\n",
       " 'Yes',\n",
       " '.',\n",
       " 'Uphold',\n",
       " 'and',\n",
       " 'continue',\n",
       " '.',\n",
       " 'No',\n",
       " '.']"
      ]
     },
     "execution_count": 19,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Get sentence data\n",
    "sentences = [[s[0] for s in sent] for sent in getter.sentences]\n",
    "sentences[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 54
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 46177,
     "status": "ok",
     "timestamp": 1588561984595,
     "user": {
      "displayName": "Soujanya Ranganatha Bhat",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjYZYrqnugKHbgoo144GZ9rzmvTfTGIL9eFkBCz=s64",
      "userId": "15617339232293464832"
     },
     "user_tz": 420
    },
    "id": "dQrZXSEcUi8p",
    "outputId": "b9c39343-9008-4489-ea2b-c8588e584173"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['P', 'I', 'I', 'I', 'I', 'I', 'P', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'P', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'P', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'P', 'I', 'I', 'I', 'I', 'P', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'P', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'P', 'I', 'I', 'I', 'I', 'P', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'P', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'P', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'P', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'P', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'P', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'P', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'Q', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'I', 'C', 'I', 'C', 'I', 'I', 'I', 'W', 'I']\n"
     ]
    }
   ],
   "source": [
    "# Get TAG labels data\n",
    "labels = [[s[1] for s in sent] for sent in getter.sentences]\n",
    "print(labels[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "kMCiMsc7Ui8u"
   },
   "source": [
    "**Convert TAG names into indices for training**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "qpRLGoGuUi8v"
   },
   "outputs": [],
   "source": [
    "tags_vals = list(set(df_data[\"TAG\"].values))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ExdzObDaUi8y"
   },
   "outputs": [],
   "source": [
    "# Add the 'X' label for WordPiece continuation sub-tokens\n",
    "# Add '[CLS]' and '[SEP]', the special tokens BERT expects around each sequence\n",
    "tags_vals.append('X')\n",
    "tags_vals.append('[CLS]')\n",
    "tags_vals.append('[SEP]')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "_AG-a8kcUi81"
   },
   "outputs": [],
   "source": [
    "tags_vals = set(tags_vals)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 46546,
     "status": "ok",
     "timestamp": 1588561985003,
     "user": {
      "displayName": "Soujanya Ranganatha Bhat",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjYZYrqnugKHbgoo144GZ9rzmvTfTGIL9eFkBCz=s64",
      "userId": "15617339232293464832"
     },
     "user_tz": 420
    },
    "id": "JxBhynbKUi83",
    "outputId": "4a114acb-0012-4411-bc22-446c839b0aab"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'C', 'I', 'P', 'Q', 'W', 'X', '[CLS]', '[SEP]'}"
      ]
     },
     "execution_count": 24,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tags_vals"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "5kC5wynXUi87"
   },
   "outputs": [],
   "source": [
    "# Map each tag name to an integer index\n",
    "# tag2idx = {t: i for i, t in enumerate(tags_vals)}\n",
    "\n",
    "# Manual definition: keeps the indices fixed across runs\n",
    "tag2idx = {'C': 2,\n",
    "           'I': 3,\n",
    "           'P': 0,\n",
    "           'Q': 1,\n",
    "           'W': 4,\n",
    "           'X': 5,\n",
    "           '[CLS]': 6,\n",
    "           '[SEP]': 7}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 46526,
     "status": "ok",
     "timestamp": 1588561985004,
     "user": {
      "displayName": "Soujanya Ranganatha Bhat",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjYZYrqnugKHbgoo144GZ9rzmvTfTGIL9eFkBCz=s64",
      "userId": "15617339232293464832"
     },
     "user_tz": 420
    },
    "id": "q0SyPToeUi8_",
    "outputId": "63cfa67a-3e64-49c8-9449-f89ac1f2ce65"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'C': 2, 'I': 3, 'P': 0, 'Q': 1, 'W': 4, 'X': 5, '[CLS]': 6, '[SEP]': 7}"
      ]
     },
     "execution_count": 26,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tag2idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "x_DHev6tUi9C"
   },
   "outputs": [],
   "source": [
    "# Reverse mapping: index to tag name\n",
    "tag2name = {idx: tag for tag, idx in tag2idx.items()}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 46506,
     "status": "ok",
     "timestamp": 1588561985004,
     "user": {
      "displayName": "Soujanya Ranganatha Bhat",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjYZYrqnugKHbgoo144GZ9rzmvTfTGIL9eFkBCz=s64",
      "userId": "15617339232293464832"
     },
     "user_tz": 420
    },
    "id": "FEkTJIkBUi9J",
    "outputId": "4e881ed8-dff4-41ef-f018-a733cc6db68f"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{0: 'P', 1: 'Q', 2: 'C', 3: 'I', 4: 'W', 5: 'X', 6: '[CLS]', 7: '[SEP]'}"
      ]
     },
     "execution_count": 28,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tag2name"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "-MaQoRbYUi9N"
   },
   "source": [
    "## Preparation - Training Data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "ORMHJaGvUi9N"
   },
   "source": [
    "Convert the raw data into trainable inputs for BERT. This includes:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "tfRMtXayUi9O"
   },
   "source": [
    "- Setting up the GPU environment\n",
    "- Loading the tokenizer and tokenizing the text\n",
    "- Building the 3 BERT inputs: token ids, attention mask, segment ids\n",
    "- Using the TRAIN and VALIDATION sets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "0oUKNJBMUi9R"
   },
   "source": [
    "**Setting-up GPU environment**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "9vtq3HojUi9S"
   },
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "n_gpu = torch.cuda.device_count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 46591,
     "status": "ok",
     "timestamp": 1588561985109,
     "user": {
      "displayName": "Soujanya Ranganatha Bhat",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjYZYrqnugKHbgoo144GZ9rzmvTfTGIL9eFkBCz=s64",
      "userId": "15617339232293464832"
     },
     "user_tz": 420
    },
    "id": "kuu0D9gaUi9U",
    "outputId": "a35ff739-dc64-4a9d-dc57-caf1a9bd8d97"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 30,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "n_gpu"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Il3J1IZ5Ui9X"
   },
   "source": [
    "### Loading Tokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "NE6BEvm7Ui9Y"
   },
   "source": [
    "Download the tokenizer vocabulary file into the GDrive folder first:\n",
    "- [vocab.txt](https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "wvvpxN3RUi9Y"
   },
   "outputs": [],
   "source": [
    "vocabulary = PARENT_DIR + \"/data/vocab.txt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "RKRuXFMiUi9a"
   },
   "outputs": [],
   "source": [
    "# Sequence length = 384 (dataset analysis: paragraph + question + answers is generally ~350 tokens)\n",
    "# CAUTION: must not exceed 512, the maximum sequence length BERT supports\n",
    "# TODO: try with increased length\n",
    "max_len = 384"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ThgQqj6wUi9d"
   },
   "outputs": [],
   "source": [
    "# Load the cased tokenizer from the local vocabulary file\n",
    "# (alternatively, BertTokenizer.from_pretrained can load it by model name)\n",
    "tokenizer = BertTokenizer(vocab_file=vocabulary, do_lower_case=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "kmXz0B4oUi9k"
   },
   "source": [
    "**Tokenize text**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "5ESveNcEUi9l"
   },
   "source": [
    "- The Hugging Face BERT tokenizer splits out-of-vocabulary (OOV) words into WordPiece sub-tokens\n",
    "- The labels must be adjusted to match the tokenized result: the first sub-token keeps the word's tag, and each continuation piece (\"##abc\") gets the label \"X\"\n",
    "- \"[CLS]\" must be added at the front and \"[SEP]\" at the end, as in the BERT paper; see [BERT indexer should add [CLS] and [SEP] tokens](https://github.com/allenai/allennlp/issues/2141)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 377
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 98554,
     "status": "ok",
     "timestamp": 1588562037110,
     "user": {
      "displayName": "Soujanya Ranganatha Bhat",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjYZYrqnugKHbgoo144GZ9rzmvTfTGIL9eFkBCz=s64",
      "userId": "15617339232293464832"
     },
     "user_tz": 420
    },
    "id": "LamYC4kWUi9m",
    "outputId": "ac49e0b1-1667-4368-ba23-7494fb8b14e0"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "No.0,len:378\n",
      "texts:[CLS] Animated history of the US . Of course the cartoon is highly overs ##im ##plified and most critics consider it one of the weak ##est parts of the film . But it makes a valid claim which you ignore entirely : That the strategy to promote gun rights for white people and to out ##law gun possession by black people was a way to up ##hold racism without letting an openly terrorist organization like the K ##K ##K flourish . Did the 19th century N ##RA in the southern states promote gun rights for black people ? I highly doubt it . But if they didn ' t one of their functions was to continue the racism of the K ##K ##K . This is the key message of this part of the animation which is again being ignored by its critics . B ##uel ##l shooting in Flint . You write : F ##act : The little boy was the class th ##ug already suspended from school for stabbing another kid with a pencil and had fought with Kay ##la the day before . This characterization of a six - year - old as a pencil - stabbing th ##ug is exactly the kind of h ##yster ##ia that Moore ' s film warns against . It is the typical right - wing reaction which looks for simple answers that do not con ##tra ##dict the Republican minds ##et . The kid was a little bastard and the parents were involved in drugs - - case closed . But why do people deal with drugs ? Because it ' s so much fun to do so ? It is by now well documented that the CIA tolerate ##d crack sales in US cities to fund the operation of South American con ##tras It is equally well known that the so - called war on drugs begun under the Nixon administration is a failure which has cost hundreds of billion ##s and made America the world leader in prison population ( both in relative and absolute numbers ) . Does the author claim the animated films message is that the N ##RA up ##hold ##s racism ? Yes . Up ##hold and continue . No . [SEP]\n",
      "No.0,len:378\n",
      "lables:[CLS] P I I I I I P I I I I I I X X I I I I I I I I I X I I I I I P I I I I I I I I I X I I I I I I I I I I I I I X I I I I I I I I I I X I I I I I I I I I I X X I I P I I I I X I I I I I I I I I I I P I I I I P I I I X X I I I I I I I I I I I I X X I P I I I I I I I I I I I I I I I I I I I P X X I I I I P I X I X X I I I I I I I X I I I I I I I I I I I I I I I I X I I I I P I I I I X X X X I I I X X I X I I I I I I X X I I X X I I I I P I I I I X X I I I I I I I I I I X X I I I X I P I I I I I I I I I I I I I X I I I P I I I I I I I P I X X I I I I I I I P I I I I I I I I I X I I I I I I I I I I I I I X I I I I I I I I X X I I I I I I I I I I I I I I I I I X I I I I I I I I I I X I I I I I X I Q I I I I I I I I I I I X I X X I I C I C X I I I W I [SEP]\n",
      "No.1,len:430\n",
      "texts:[CLS] Animated history of the US . Of course the cartoon is highly overs ##im ##plified and most critics consider it one of the weak ##est parts of the film . But it makes a valid claim which you ignore entirely : That the strategy to promote gun rights for white people and to out ##law gun possession by black people was a way to up ##hold racism without letting an openly terrorist organization like the K ##K ##K flourish . Did the 19th century N ##RA in the southern states promote gun rights for black people ? I highly doubt it . But if they didn ' t one of their functions was to continue the racism of the K ##K ##K . This is the key message of this part of the animation which is again being ignored by its critics . B ##uel ##l shooting in Flint . You write : F ##act : The little boy was the class th ##ug already suspended from school for stabbing another kid with a pencil and had fought with Kay ##la the day before . This characterization of a six - year - old as a pencil - stabbing th ##ug is exactly the kind of h ##yster ##ia that Moore ' s film warns against . It is the typical right - wing reaction which looks for simple answers that do not con ##tra ##dict the Republican minds ##et . The kid was a little bastard and the parents were involved in drugs - - case closed . But why do people deal with drugs ? Because it ' s so much fun to do so ? It is by now well documented that the CIA tolerate ##d crack sales in US cities to fund the operation of South American con ##tras It is equally well known that the so - called war on drugs begun under the Nixon administration is a failure which has cost hundreds of billion ##s and made America the world leader in prison population ( both in relative and absolute numbers ) . Which key message ( s ) do ( es ) this passage say the critics ignored ? 
The strategy to promote gun rights for white people while out ##law ##ing it for black people allowed r ##ac ##isi ##m to continue without allowing to K ##K ##K to flourish . That it ant ##agon ##ized the K ##K ##K . That the K ##K ##K was a terrorist organization . The strategy to promote the K ##K ##K . [SEP]\n",
      "No.1,len:430\n",
      "lables:[CLS] P I I I I I P I I I I I I X X I I I I I I I I I X I I I I I P I I I I I I I I I X I I I I I I I I I I I I I X I I I I I I I I I I X I I I I I I I I I I X X I I P I I I I X I I I I I I I I I I I P I I I I P I I I X X I I I I I I I I I I I I X X I P I I I I I I I I I I I I I I I I I I I P X X I I I I P I X I X X I I I I I I I X I I I I I I I I I I I I I I I I X I I I I P I I I I X X X X I I I X X I X I I I I I I X X I I X X I I I I P I I I I X X I I I I I I I I I I X X I I I X I P I I I I I I I I I I I I I X I I I P I I I I I I I P I X X I I I I I I I P I I I I I I I I I X I I I I I I I I I I I I I X I I I I I I I I X X I I I I I I I I I I I I I I I I I X I I I I I I I I I I X I I I I I X I Q I I X X X I X X X I I I I I I I C I I I I I I I I I I X X I I I I I I X X X I I I I I I X X I I I W I I X X I I X X I W I I X X I I I I I W I I I I I X X I [SEP]\n",
      "No.2,len:488\n",
      "texts:[CLS] Animated history of the US . Of course the cartoon is highly overs ##im ##plified and most critics consider it one of the weak ##est parts of the film . But it makes a valid claim which you ignore entirely : That the strategy to promote gun rights for white people and to out ##law gun possession by black people was a way to up ##hold racism without letting an openly terrorist organization like the K ##K ##K flourish . Did the 19th century N ##RA in the southern states promote gun rights for black people ? I highly doubt it . But if they didn ' t one of their functions was to continue the racism of the K ##K ##K . This is the key message of this part of the animation which is again being ignored by its critics . B ##uel ##l shooting in Flint . You write : F ##act : The little boy was the class th ##ug already suspended from school for stabbing another kid with a pencil and had fought with Kay ##la the day before . This characterization of a six - year - old as a pencil - stabbing th ##ug is exactly the kind of h ##yster ##ia that Moore ' s film warns against . It is the typical right - wing reaction which looks for simple answers that do not con ##tra ##dict the Republican minds ##et . The kid was a little bastard and the parents were involved in drugs - - case closed . But why do people deal with drugs ? Because it ' s so much fun to do so ? It is by now well documented that the CIA tolerate ##d crack sales in US cities to fund the operation of South American con ##tras It is equally well known that the so - called war on drugs begun under the Nixon administration is a failure which has cost hundreds of billion ##s and made America the world leader in prison population ( both in relative and absolute numbers ) . What type of the film is being discussed and what is on of the key messages ? Animated history of the US and one of the key messages is continuing the racist goals of the K ##K ##K . 
This is cartoon film depicted violence in elementary schools among six - year olds . . Animated key message : how gun rights are promoted for whites out ##law ##ed for blacks . It is an animated history of the US and one of the key messages is to continue the racism of the K ##K ##K . L ##I ##ve action key message : N ##RA in the southern states promote gun rights for black people . Documentary key message : the CIA fought cocaine smuggling into the U . S . in the ' 80s . [SEP]\n",
      "No.2,len:488\n",
      "lables:[CLS] P I I I I I P I I I I I I X X I I I I I I I I I X I I I I I P I I I I I I I I I X I I I I I I I I I I I I I X I I I I I I I I I I X I I I I I I I I I I X X I I P I I I I X I I I I I I I I I I I P I I I I P I I I X X I I I I I I I I I I I I X X I P I I I I I I I I I I I I I I I I I I I P X X I I I I P I X I X X I I I I I I I X I I I I I I I I I I I I I I I I X I I I I P I I I I X X X X I I I X X I X I I I I I I X X I I X X I I I I P I I I I X X I I I I I I I I I I X X I I I X I P I I I I I I I I I I I I I X I I I P I I I I I I I P I X X I I I I I I I P I I I I I I I I I X I I I I I I I I I I I I I X I I I I I I I I X X I I I I I I I I I I I I I I I I I X I I I I I I I I I I X I I I I I X I Q I I I I I I I I I I I I I I I I C I I I I I I I I I I I I I I I I I I X X I W I I I I I I I I I I X X I I X C I I X I I I I I I I I X X I I I C I I I I I I I I I I I I I I I I I I I I I X X I W X X I I I X I X I I I I I I I I I I I W I I X I I I I I I I I I X X I I I X I [SEP]\n",
      "No.3,len:446\n",
      "texts:[CLS] Animated history of the US . Of course the cartoon is highly overs ##im ##plified and most critics consider it one of the weak ##est parts of the film . But it makes a valid claim which you ignore entirely : That the strategy to promote gun rights for white people and to out ##law gun possession by black people was a way to up ##hold racism without letting an openly terrorist organization like the K ##K ##K flourish . Did the 19th century N ##RA in the southern states promote gun rights for black people ? I highly doubt it . But if they didn ' t one of their functions was to continue the racism of the K ##K ##K . This is the key message of this part of the animation which is again being ignored by its critics . B ##uel ##l shooting in Flint . You write : F ##act : The little boy was the class th ##ug already suspended from school for stabbing another kid with a pencil and had fought with Kay ##la the day before . This characterization of a six - year - old as a pencil - stabbing th ##ug is exactly the kind of h ##yster ##ia that Moore ' s film warns against . It is the typical right - wing reaction which looks for simple answers that do not con ##tra ##dict the Republican minds ##et . The kid was a little bastard and the parents were involved in drugs - - case closed . But why do people deal with drugs ? Because it ' s so much fun to do so ? It is by now well documented that the CIA tolerate ##d crack sales in US cities to fund the operation of South American con ##tras It is equally well known that the so - called war on drugs begun under the Nixon administration is a failure which has cost hundreds of billion ##s and made America the world leader in prison population ( both in relative and absolute numbers ) . Which type of rights are being discussed and promoted by which group ? Under ##cover drug operations by the N ##RA . C ##rack ##down on crack dealing by the K ##K ##K . 
T ##hr right to use drugs if one sees fit was discussed by the Legal Aid Society . Disc ##uss ##ion of a strategy to promote gun rights for white people was discussed by N ##RS and K ##K ##K . The discussion of gun rights being promoted by the N ##RA . Gun rights promoted by the N ##RA . [SEP]\n",
      "No.3,len:446\n",
      "lables:[CLS] P I I I I I P I I I I I I X X I I I I I I I I I X I I I I I P I I I I I I I I I X I I I I I I I I I I I I I X I I I I I I I I I I X I I I I I I I I I I X X I I P I I I I X I I I I I I I I I I I P I I I I P I I I X X I I I I I I I I I I I I X X I P I I I I I I I I I I I I I I I I I I I P X X I I I I P I X I X X I I I I I I I X I I I I I I I I I I I I I I I I X I I I I P I I I I X X X X I I I X X I X I I I I I I X X I I X X I I I I P I I I I X X I I I I I I I I I I X X I I I X I P I I I I I I I I I I I I I X I I I P I I I I I I I P I X X I I I I I I I P I I I I I I I I I X I I I I I I I I I I I I I X I I I I I I I I X X I I I I I I I I I I I I I I I I I X I I I I I I I I I I X I I I I I X I Q I I I I I I I I I I I I W X I I I I I X I W X X I I I I I I X X I W X I I I I I I I I I I I I I I I I C X X I I I I I I I I I I I I I I X I I X X I C I I I I I I I I I X I C I I I I I X I [SEP]\n",
      "No.4,len:386\n",
      "texts:[CLS] Animated history of the US . Of course the cartoon is highly overs ##im ##plified and most critics consider it one of the weak ##est parts of the film . But it makes a valid claim which you ignore entirely : That the strategy to promote gun rights for white people and to out ##law gun possession by black people was a way to up ##hold racism without letting an openly terrorist organization like the K ##K ##K flourish . Did the 19th century N ##RA in the southern states promote gun rights for black people ? I highly doubt it . But if they didn ' t one of their functions was to continue the racism of the K ##K ##K . This is the key message of this part of the animation which is again being ignored by its critics . B ##uel ##l shooting in Flint . You write : F ##act : The little boy was the class th ##ug already suspended from school for stabbing another kid with a pencil and had fought with Kay ##la the day before . This characterization of a six - year - old as a pencil - stabbing th ##ug is exactly the kind of h ##yster ##ia that Moore ' s film warns against . It is the typical right - wing reaction which looks for simple answers that do not con ##tra ##dict the Republican minds ##et . The kid was a little bastard and the parents were involved in drugs - - case closed . But why do people deal with drugs ? Because it ' s so much fun to do so ? It is by now well documented that the CIA tolerate ##d crack sales in US cities to fund the operation of South American con ##tras It is equally well known that the so - called war on drugs begun under the Nixon administration is a failure which has cost hundreds of billion ##s and made America the world leader in prison population ( both in relative and absolute numbers ) . In the author ' s mind which characterization of the B ##uel ##l school shooter is more appropriate ? T ##hu ##g or Ba ##star ##d ? Ba ##star ##d . T ##hu ##g . [SEP]\n",
      "No.4,len:386\n",
      "lables:[CLS] P I I I I I P I I I I I I X X I I I I I I I I I X I I I I I P I I I I I I I I I X I I I I I I I I I I I I I X I I I I I I I I I I X I I I I I I I I I I X X I I P I I I I X I I I I I I I I I I I P I I I I P I I I X X I I I I I I I I I I I I X X I P I I I I I I I I I I I I I I I I I I I P X X I I I I P I X I X X I I I I I I I X I I I I I I I I I I I I I I I I X I I I I P I I I I X X X X I I I X X I X I I I I I I X X I I X X I I I I P I I I I X X I I I I I I I I I I X X I I I X I P I I I I I I I I I I I I I X I I I P I I I I I I I P I X X I I I I I I I P I I I I I I I I I X I I I I I I I I I I I I I X I I I I I I I I X X I I I I I I I I I I I I I I I I I X I I I I I I I I I I X I I I I I X I Q I I X X I I I I I I X X I I I I I I I X X I I X X I C X X I W X X I [SEP]\n"
     ]
    }
   ],
   "source": [
    "tokenized_texts = []\n",
    "word_piece_labels = []\n",
    "i_inc = 0\n",
    "for word_list,label in (zip(sentences,labels)):\n",
    "    temp_lable = []\n",
    "    temp_token = []\n",
    "    \n",
    "    # Add [CLS] at the front \n",
    "    temp_lable.append('[CLS]')\n",
    "    temp_token.append('[CLS]')\n",
    "    \n",
    "    for word,lab in zip(word_list,label):\n",
    "        token_list = tokenizer.tokenize(word)\n",
    "        for m,token in enumerate(token_list):\n",
    "            temp_token.append(token)\n",
    "            if m==0:\n",
    "                temp_lable.append(lab)\n",
    "            else:\n",
    "                temp_lable.append('X')  \n",
    "                \n",
    "    # Add [SEP] at the end\n",
    "    temp_lable.append('[SEP]')\n",
    "    temp_token.append('[SEP]')\n",
    "    \n",
    "    tokenized_texts.append(temp_token)\n",
    "    word_piece_labels.append(temp_lable)\n",
    "    \n",
    "    if 5 > i_inc:\n",
    "        print(\"No.%d,len:%d\"%(i_inc,len(temp_token)))\n",
    "        print(\"texts:%s\"%(\" \".join(temp_token)))\n",
    "        print(\"No.%d,len:%d\"%(i_inc,len(temp_lable)))\n",
    "        print(\"lables:%s\"%(\" \".join(temp_lable)))\n",
    "    i_inc +=1\n",
    "    \n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "36NKtzIoUi9y"
   },
   "source": [
    "### Setting-up token embedding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Zf0AGkN1Ui9y"
   },
   "source": [
    "Pad or trim the text and label to fit the need for max len"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 561
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 100318,
     "status": "ok",
     "timestamp": 1588562038890,
     "user": {
      "displayName": "Soujanya Ranganatha Bhat",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjYZYrqnugKHbgoo144GZ9rzmvTfTGIL9eFkBCz=s64",
      "userId": "15617339232293464832"
     },
     "user_tz": 420
    },
    "id": "LgI58czOUi9z",
    "outputId": "711158c3-ec04-4dd7-80d8-98e396914b85"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[  101 24238  1607  1104  1103  1646   119  2096  1736  1103 11540  1110\n",
      "  3023 17074  4060 18580  1105  1211  4217  4615  1122  1141  1104  1103\n",
      "  4780  2556  2192  1104  1103  1273   119  1252  1122  2228   170  9221\n",
      "  3548  1134  1128  8429  3665   131  1337  1103  5564  1106  4609  2560\n",
      "  2266  1111  1653  1234  1105  1106  1149  9598  2560  6224  1118  1602\n",
      "  1234  1108   170  1236  1106  1146  8678 16654  1443  5074  1126  9990\n",
      "  9640  2369  1176  1103   148  2428  2428 27760   119  2966  1103  2835\n",
      "  1432   151  9664  1107  1103  2359  2231  4609  2560  2266  1111  1602\n",
      "  1234   136   146  3023  4095  1122   119  1252  1191  1152  1238   112\n",
      "   189  1141  1104  1147  4226  1108  1106  2760  1103 16654  1104  1103\n",
      "   148  2428  2428   119  1188  1110  1103  2501  3802  1104  1142  1226\n",
      "  1104  1103  8794  1134  1110  1254  1217  5794  1118  1157  4217   119\n",
      "   139 24741  1233  4598  1107 17741   119  1192  3593   131   143 11179\n",
      "   131  1109  1376  2298  1108  1103  1705 24438  9610  1640  6232  1121\n",
      "  1278  1111 24728  1330  5102  1114   170 16372  1105  1125  3214  1114\n",
      " 11247  1742  1103  1285  1196   119  1188 27419  1104   170  1565   118\n",
      "  1214   118  1385  1112   170 16372   118 24728 24438  9610  1110  2839\n",
      "  1103  1912  1104   177 21878  1465  1115  4673   112   188  1273 21310\n",
      "  1222   119  1135  1110  1103  4701  1268   118  3092  3943  1134  2736\n",
      "  1111  3014  6615  1115  1202  1136 14255  4487 28113  1103  3215 10089\n",
      "  2105   119  1109  5102  1108   170  1376  8735  1105  1103  2153  1127\n",
      "  2017  1107  5557   118   118  1692  1804   119  1252  1725  1202  1234\n",
      "  2239  1114  5557   136  2279  1122   112   188  1177  1277  4106  1106\n",
      "  1202  1177   136  1135  1110  1118  1208  1218  8510  1115  1103  9878\n",
      " 21073  1181  8672  3813  1107  1646  3038  1106  5841  1103  2805  1104\n",
      "  1375  1237 14255 25352  1135  1110  7808  1218  1227  1115  1103  1177\n",
      "   118  1270  1594  1113  5557  4972  1223  1103 11302  3469  1110   170\n",
      "  4290  1134  1144  2616  5229  1104  3775  1116  1105  1189  1738  1103\n",
      "  1362  2301  1107  3315  1416   113  1241  1107  5236  1105  7846  2849\n",
      "   114   119  7187  1103  2351  3548  1103  6608  2441  3802  1110  1115\n",
      "  1103   151  9664  1146  8678  1116 16654   136  2160   119  3725  8678\n",
      "  1105  2760   119  1302   119   102     0     0     0     0     0     0]\n"
     ]
    }
   ],
   "source": [
    "# Make text token into id\n",
    "input_ids = pad_sequences([tokenizer.convert_tokens_to_ids(txt) for txt in tokenized_texts],\n",
    "                          maxlen=max_len, dtype=\"long\", truncating=\"post\", padding=\"post\")\n",
    "print(input_ids[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 204
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 100785,
     "status": "ok",
     "timestamp": 1588562039369,
     "user": {
      "displayName": "Soujanya Ranganatha Bhat",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjYZYrqnugKHbgoo144GZ9rzmvTfTGIL9eFkBCz=s64",
      "userId": "15617339232293464832"
     },
     "user_tz": 420
    },
    "id": "1hT4lASzUi91",
    "outputId": "b25cf780-01f4-4ede-9b3b-54f737a3a88d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[6 0 3 3 3 3 3 0 3 3 3 3 3 3 5 5 3 3 3 3 3 3 3 3 3 5 3 3 3 3 3 0 3 3 3 3 3\n",
      " 3 3 3 3 5 3 3 3 3 3 3 3 3 3 3 3 3 3 5 3 3 3 3 3 3 3 3 3 3 5 3 3 3 3 3 3 3\n",
      " 3 3 3 5 5 3 3 0 3 3 3 3 5 3 3 3 3 3 3 3 3 3 3 3 0 3 3 3 3 0 3 3 3 5 5 3 3\n",
      " 3 3 3 3 3 3 3 3 3 3 5 5 3 0 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 0 5 5 3\n",
      " 3 3 3 0 3 5 3 5 5 3 3 3 3 3 3 3 5 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 5 3 3 3\n",
      " 3 0 3 3 3 3 5 5 5 5 3 3 3 5 5 3 5 3 3 3 3 3 3 5 5 3 3 5 5 3 3 3 3 0 3 3 3\n",
      " 3 5 5 3 3 3 3 3 3 3 3 3 3 5 5 3 3 3 5 3 0 3 3 3 3 3 3 3 3 3 3 3 3 3 5 3 3\n",
      " 3 0 3 3 3 3 3 3 3 0 3 5 5 3 3 3 3 3 3 3 0 3 3 3 3 3 3 3 3 3 5 3 3 3 3 3 3\n",
      " 3 3 3 3 3 3 3 5 3 3 3 3 3 3 3 3 5 5 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 3 5 3\n",
      " 3 3 3 3 3 3 3 3 3 5 3 3 3 3 3 5 3 1 3 3 3 3 3 3 3 3 3 3 3 5 3 5 5 3 3 2 3\n",
      " 2 5 3 3 3 4 3 7 4 4 4 4 4 4]\n"
     ]
    }
   ],
   "source": [
    "# Make label into id, pad with \"W\" meaning others/wrong\n",
    "# Note - Replaced \"O\" -> \"W\" (wrong)\n",
    "tags = pad_sequences([[tag2idx.get(l) for l in lab] for lab in word_piece_labels],\n",
    "                     maxlen=max_len, value=tag2idx[\"W\"], padding=\"post\",\n",
    "                     dtype=\"long\", truncating=\"post\")\n",
    "print(tags[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Q90_O7JVUi96"
   },
   "source": [
    "### Setting-up mask word embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "E03rTe7kUi96"
   },
   "outputs": [],
   "source": [
    "# For fine tune of predict, with token mask is 1,pad token is 0\n",
    "attention_masks = [[int(i>0) for i in ii] for ii in input_ids]\n",
    "attention_masks[0];"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "mg7m0PyoUi99"
   },
   "source": [
    "### Setting-up segment embedding(Analysis- for sequance tagging task, it's not necessary to make this embedding)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "M2mMlimtUi99"
   },
   "outputs": [],
   "source": [
    "# Since only one sentence, all the segment set to 0\n",
    "segment_ids = [[0] * len(input_id) for input_id in input_ids]\n",
    "segment_ids[0];"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "s1fH7smpci0t"
   },
   "outputs": [],
   "source": [
    "# print(segment_ids) # ERROR - IOPub data rate exceeded. (TOO MUCH!)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "HXfDgtzuUi-F"
   },
   "source": [
    "## Load TRAIN and VALIDATION sets"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "z37-OiuHUi-I"
   },
   "source": [
    "**Split all data**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ArF7Fq8fUi-J"
   },
   "outputs": [],
   "source": [
    "tr_inputs, tr_tags, tr_masks, tr_segs = input_ids, tags, attention_masks, segment_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 102163,
     "status": "ok",
     "timestamp": 1588562040795,
     "user": {
      "displayName": "Soujanya Ranganatha Bhat",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjYZYrqnugKHbgoo144GZ9rzmvTfTGIL9eFkBCz=s64",
      "userId": "15617339232293464832"
     },
     "user_tz": 420
    },
    "id": "lYcZ4p6lUi-O",
    "outputId": "dc48152f-174f-4201-9334-98860ad26967"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5131, 5131)"
      ]
     },
     "execution_count": 41,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(tr_inputs),len(tr_segs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 136
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 102151,
     "status": "ok",
     "timestamp": 1588562040795,
     "user": {
      "displayName": "Soujanya Ranganatha Bhat",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjYZYrqnugKHbgoo144GZ9rzmvTfTGIL9eFkBCz=s64",
      "userId": "15617339232293464832"
     },
     "user_tz": 420
    },
    "id": "v38ZLJ0-tKDp",
    "outputId": "fcbcbfa5-5eb5-4f8c-d037-38db88be6fde"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[  101 24238  1607 ...     0     0     0]\n",
      " [  101 24238  1607 ...  1111  1602  1234]\n",
      " [  101 24238  1607 ... 18848  2513  1104]\n",
      " ...\n",
      " [  101   158   119 ...  1523  1113  9170]\n",
      " [  101   158   119 ...  1523  1113  9170]\n",
      " [  101   158   119 ...  1523  1113  9170]]\n"
     ]
    }
   ],
   "source": [
    "print(tr_inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "wiMXgofyUi-T"
   },
   "source": [
    "**Set data into tensor**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "JXbuWgNaUi-T"
   },
   "source": [
    "NOTE - Not recommend tensor.to(device) at this process, since it will run out of GPU memory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "_LEsSzo2Ui-U"
   },
   "outputs": [],
   "source": [
    "tr_inputs = torch.tensor(tr_inputs)\n",
    "tr_tags = torch.tensor(tr_tags)\n",
    "tr_masks = torch.tensor(tr_masks)\n",
    "tr_segs = torch.tensor(tr_segs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "7nR_mX0aUi-a"
   },
   "source": [
    "**Put data into data loader**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "-iD8Yay_Ui-b"
   },
   "outputs": [],
   "source": [
    "# Set batch num\n",
    "batch_num = 16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "uFWiVf50Ui-e"
   },
   "outputs": [],
   "source": [
    "# Only set token embedding, attention embedding, no segment embedding\n",
    "train_data = TensorDataset(tr_inputs, tr_masks, tr_tags)\n",
    "train_sampler = RandomSampler(train_data)\n",
    "# Drop last can make batch training better for the last one\n",
    "train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_num,drop_last=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "C4JdNTXXUi-i"
   },
   "source": [
    "## Train model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "5bIB2t48Ui-i"
   },
   "source": [
    "- Pre-requisite: Downloading model files in GDrive\n",
    "- Model used - BERT-base-cased\n",
    "- pytorch_model.bin: [pytorch_model.bin](https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin)\n",
    "- config.json: [config.json](https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json)    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "tqQuPl0yUi-j"
   },
   "source": [
    "**Loading BERT model**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "CpnX05xiUi-j"
   },
   "outputs": [],
   "source": [
    "# In this folder, contain model confg(json) and model weight(bin) files\n",
    "# pytorch_model.bin, download from: https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin\n",
    "# config.json, downlaod from: https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json\n",
    "model_file_address = PARENT_DIR + \"/models\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 103424,
     "status": "ok",
     "timestamp": 1588562042109,
     "user": {
      "displayName": "Soujanya Ranganatha Bhat",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjYZYrqnugKHbgoo144GZ9rzmvTfTGIL9eFkBCz=s64",
      "userId": "15617339232293464832"
     },
     "user_tz": 420
    },
    "id": "Gtov4DiFfinI",
    "outputId": "c55f4905-72b3-4de9-9fb5-5d4f7d368014"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "config.json  pytorch_model.bin\n"
     ]
    }
   ],
   "source": [
    "!ls \"/content/gdrive/My Drive/MultiRC_NER/models\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Jr7114QmUi-k"
   },
   "outputs": [],
   "source": [
    "# Will load config and weight with from_pretrained()\n",
    "model = BertForTokenClassification.from_pretrained(model_file_address,num_labels=len(tag2idx))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Z55qnZFjUi-m"
   },
   "outputs": [],
   "source": [
    "model;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "nX7OqEG4Ui-o"
   },
   "outputs": [],
   "source": [
    "# Set model to GPU,if you are using GPU machine\n",
    "model.cuda();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "rXgBBGaYUi-p"
   },
   "outputs": [],
   "source": [
    "# Add multi GPU support\n",
    "#if n_gpu >1:\n",
    " #   model = torch.nn.DataParallel(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ToxEzMs9Ui-r"
   },
   "outputs": [],
   "source": [
    "# Set epoch and grad max num\n",
    "epochs = 5\n",
    "max_grad_norm = 1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "vDeMuG_qUi-t"
   },
   "outputs": [],
   "source": [
    "# Cacluate train optimiazaion num\n",
    "num_train_optimization_steps = int( math.ceil(len(tr_inputs) / batch_num) / 1) * epochs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "FJ1liMnDUi-z"
   },
   "source": [
    "### Setting-up fine tuning method"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "43hTJyHJUi-z"
   },
   "source": [
    "**Manual optimizer**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "OvuJtdG7Ui-z"
   },
   "outputs": [],
   "source": [
    "# True: fine tuning all the layers \n",
    "# False: only fine tuning the classifier layers\n",
    "FULL_FINETUNING = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Iv1pdD9YUi-2"
   },
   "outputs": [],
   "source": [
    "if FULL_FINETUNING:\n",
    "    # Fine tune model all layer parameters\n",
    "    param_optimizer = list(model.named_parameters())\n",
    "    no_decay = ['bias', 'gamma', 'beta']\n",
    "    optimizer_grouped_parameters = [\n",
    "        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],\n",
    "         'weight_decay_rate': 0.01},\n",
    "        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],\n",
    "         'weight_decay_rate': 0.0}\n",
    "    ]\n",
    "else:\n",
    "    # Only fine tune classifier parameters\n",
    "    param_optimizer = list(model.classifier.named_parameters()) \n",
    "    optimizer_grouped_parameters = [{\"params\": [p for n, p in param_optimizer]}]\n",
    "optimizer = AdamW(optimizer_grouped_parameters, lr=3e-5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "7ZybMSrOUi-8"
   },
   "source": [
    "### Fine-tuning model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Lz3zEEi_Ui-8"
   },
   "outputs": [],
   "source": [
    "# TRAIN loop\n",
    "model.train();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "mS_C97Sou9KF"
   },
   "outputs": [],
   "source": [
    "# Check logs for crash\n",
    "#!cat /var/log/colab-jupyter.log"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 255
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 1125990,
     "status": "ok",
     "timestamp": 1588563064766,
     "user": {
      "displayName": "Soujanya Ranganatha Bhat",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjYZYrqnugKHbgoo144GZ9rzmvTfTGIL9eFkBCz=s64",
      "userId": "15617339232293464832"
     },
     "user_tz": 420
    },
    "id": "DZ95271QUi--",
    "outputId": "00a5beea-8f5a-4da6-85ca-a6b96bad618f",
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Epoch:   0%|          | 0/5 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "***** Running training *****\n",
      "  Num examples = 5131\n",
      "  Batch size = 16\n",
      "  Num steps = 1605\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/pytorch/torch/csrc/utils/python_arg_parser.cpp:756: UserWarning: This overload of add_ is deprecated:\n",
      "\tadd_(Number alpha, Tensor other)\n",
      "Consider using one of the following signatures instead:\n",
      "\tadd_(Tensor other, *, Number alpha)\n",
      "Epoch:  20%|██        | 1/5 [03:20<13:20, 200.05s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train loss: 0.08483015493693528\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Epoch:  40%|████      | 2/5 [06:39<09:59, 199.88s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train loss: 0.009345573517930462\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Epoch:  60%|██████    | 3/5 [09:58<06:39, 199.72s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train loss: 0.007660632718761917\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "Epoch:  80%|████████  | 4/5 [13:18<03:19, 199.59s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train loss: 0.006583411008614348\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch: 100%|██████████| 5/5 [16:37<00:00, 199.50s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train loss: 0.005635686671666917\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "print(\"***** Running training *****\")\n",
    "print(\"  Num examples = %d\"%(len(tr_inputs)))\n",
    "print(\"  Batch size = %d\"%(batch_num))\n",
    "print(\"  Num steps = %d\"%(num_train_optimization_steps))\n",
    "for _ in trange(epochs,desc=\"Epoch\"):\n",
    "    tr_loss = 0\n",
    "    nb_tr_examples, nb_tr_steps = 0, 0\n",
    "    for step, batch in enumerate(train_dataloader):\n",
    "        # add batch to gpu\n",
    "        batch = tuple(t.to(device) for t in batch)\n",
    "        b_input_ids, b_input_mask, b_labels = batch\n",
    "        \n",
    "        # forward pass\n",
    "        outputs = model(b_input_ids, token_type_ids=None,\n",
    "        attention_mask=b_input_mask, labels=b_labels)\n",
    "        loss, scores = outputs[:2]\n",
    "      #  if n_gpu>1:\n",
    "            # When multi gpu, average it\n",
    "       #     loss = loss.mean()\n",
    "        \n",
    "        # backward pass\n",
    "        loss.backward()\n",
    "        \n",
    "        # track train loss\n",
    "        tr_loss += loss.item()\n",
    "        nb_tr_examples += b_input_ids.size(0)\n",
    "        nb_tr_steps += 1\n",
    "        \n",
    "        # gradient clipping\n",
    "        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm)\n",
    "        \n",
    "        # update parameters\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "        \n",
    "    # print train loss per epoch\n",
    "    print(\"Train loss: {}\".format(tr_loss/nb_tr_steps))\n",
    "        "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "CHPIwbuGUi_C"
   },
   "source": [
    "## Save model "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "YsUIIUCSUi_C"
   },
   "outputs": [],
   "source": [
    "# TODO: output/ => original data, output/sample/ => sampled data\n",
    "bert_out_address = PARENT_DIR + \"/output/trained_v2\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "FpSNSkWtUi_E"
   },
   "outputs": [],
   "source": [
    "# Make dir if not exits\n",
    "if not os.path.exists(bert_out_address):\n",
    "        os.makedirs(bert_out_address)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "DdOW-B46Ui_I"
   },
   "outputs": [],
   "source": [
    "# Save a trained model, configuration and tokenizer\n",
    "model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "P2uAveVXUi_L"
   },
   "outputs": [],
   "source": [
    "# If we save using the predefined names, we can load using `from_pretrained`\n",
    "output_model_file = os.path.join(bert_out_address, \"pytorch_model.bin\")\n",
    "output_config_file = os.path.join(bert_out_address, \"config.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 1127765,
     "status": "ok",
     "timestamp": 1588563066585,
     "user": {
      "displayName": "Soujanya Ranganatha Bhat",
      "photoUrl": "https://lh3.googleusercontent.com/a-/AOh14GjYZYrqnugKHbgoo144GZ9rzmvTfTGIL9eFkBCz=s64",
      "userId": "15617339232293464832"
     },
     "user_tz": 420
    },
    "id": "GWx_l9fkUi_N",
    "outputId": "4b95767f-fd78-46b4-c3a2-7c46c6c07e52",
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('/content/gdrive/My Drive/MultiRC_NER/output/trained_v2/vocab.txt',)"
      ]
     },
     "execution_count": 63,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Save model into file\n",
    "torch.save(model_to_save.state_dict(), output_model_file)\n",
    "model_to_save.config.to_json_file(output_config_file)\n",
    "tokenizer.save_vocabulary(bert_out_address)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "kdOCouaEzzNb"
   },
   "source": [
    "# ----------- END OF TRAINING -----------"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "lZpyxO9_zjCS"
   },
   "source": [
    "# Refer to MultiRC-NER_eval note book for EVALUATIONS & ANALYSIS"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [
    "3yWP5SKUUi_S",
    "J3BVkBWLUi_Z"
   ],
   "machine_shape": "hm",
   "name": "MultiRC-NER.ipynb",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
